/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com> 
 *
 */
#ifndef SRC_PROTOCOLS_DTLS_DTLSPROTOCOL_H_
#define SRC_PROTOCOLS_DTLS_DTLSPROTOCOL_H_

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "Protocol.h"
#include <arpa/inet.h>
#include "flow/FlowManager.h"
#include "Cache.h"
#include "DTLSInfo.h"

namespace aiengine {

// Minimum DTLS record header.
// The struct is packed so that it can be overlaid directly on the raw bytes
// of a packet; no padding is inserted between the fields. Multi-byte fields
// are read straight from the wire and are NOT byte-swapped by this struct.
// The sequence number occupies 6 consecutive bytes and is split in two
// fields because there is no native 48-bit integer type.
struct dtls_header {
        uint8_t         type;           /* DTLS record type (see DTLS_CT_* values) */
        uint16_t        version;        /* DTLS version (major/minor) */
        uint16_t        epoch;          /* Epoch */
        uint32_t        seq1;           /* Sequence number, first 4 bytes */
        uint16_t        seq0;           /* Sequence number, last 2 bytes */
        uint16_t        length;         /* Length of data */
        uint8_t         data[0];        /* Zero-length marker for the bytes that follow the header */
} __attribute__((packed));

// Guard the layout: 1 + 2 + 2 + 4 + 2 + 2 bytes. If the packing is ever lost
// every field after 'type' would be read from the wrong offset.
static_assert(sizeof(dtls_header) == 13, "dtls_header must be exactly 13 bytes (packed DTLS record header)");

// Record content types (the values are the same ones used by TLS 1.2)
#define DTLS_CT_CHANGE_CIPHER_SPEC      20   // 0x14
#define DTLS_CT_ALERT                   21   // 0x15
#define DTLS_CT_HANDSHAKE               22   // 0x16
#define DTLS_CT_APPLICATION_DATA        23   // 0x17

// Message types found inside a handshake record
#define DTLS_MT_HELLO_REQUEST            0   // 0x00
#define DTLS_MT_CLIENT_HELLO             1   // 0x01
#define DTLS_MT_SERVER_HELLO             2   // 0x02
#define DTLS_MT_HELLO_VERIFY             3   // 0x03
#define DTLS_MT_NEW_SESSION_TICKET       4   // 0x04
#define DTLS_MT_CERTIFICATE             11   // 0x0B
#define DTLS_MT_SERVER_KEY_EXCHANGE     12   // 0x0C
#define DTLS_MT_CERTIFICATE_REQUEST     13   // 0x0D
#define DTLS_MT_SERVER_DONE             14   // 0x0E
#define DTLS_MT_CERTIFICATE_VERIFY      15   // 0x0F
#define DTLS_MT_CLIENT_KEY_EXCHANGE     16   // 0x10
#define DTLS_MT_FINISHED                20   // 0x14

// Protocol handler for DTLS. It is constructed with the name "DTLS" over
// IPPROTO_UDP and keeps per-record-type and per-handshake-message counters.
class DTLSProtocol: public Protocol {
public:
	explicit DTLSProtocol():
		Protocol("DTLS", IPPROTO_UDP)
	{
	}

	virtual ~DTLSProtocol() = default;

	// Number of bytes of the fixed DTLS record header.
	static constexpr uint16_t header_size = sizeof(dtls_header);

	// This protocol reports 0x0000 as its identifier.
	uint16_t getId() const override { return 0x0000; }
	uint16_t getHeaderSize() const override { return header_size; }

	// Decides whether the given packet is a DTLS packet.
	bool check(const Packet &packet) override;
	void processFlow(Flow *flow) override;
	// Always reports success; nothing is done with the packet here.
	bool processPacket(Packet& packet) override { return true; }

	void statistics(std::basic_ostream<char>& out, int level, int32_t limit) const override;
	void statistics(Json &out, int level) const override;

	void releaseCache() override;

	// Points header_ at the given raw bytes, viewed as a dtls_header.
	// Nothing is copied, so the buffer must outlive any use of header_.
	void setHeader(const uint8_t *raw_packet) override {
		header_ = reinterpret_cast<const dtls_header*>(raw_packet);
	}

	void increaseAllocatedMemory(int value) override;
	void decreaseAllocatedMemory(int value) override;

#if defined(STAND_ALONE_TEST) || defined(TESTING)

	// Counter accessors, only compiled for the test builds.

	// Per record type
	uint32_t getTotalHandshakes() const { return total_handshakes_; }
	uint32_t getTotalEncryptedHandshakes() const { return total_encrypted_handshakes_; }
	uint32_t getTotalAlerts() const { return total_alerts_; }
	uint32_t getTotalChangeCipherSpecs() const { return total_change_cipher_specs_; }
	uint32_t getTotalDatas() const { return total_data_; }

	// Per handshake message type
	uint32_t getTotalClientHellos() const { return total_client_hellos_; }
	uint32_t getTotalHelloVerifyRequests() const { return total_hello_verifies_requests_; }
	uint32_t getTotalServerHellos() const { return total_server_hellos_; }
	uint32_t getTotalCertificates() const { return total_certificates_; }
	uint32_t getTotalCertificateRequests() const { return total_certificate_requests_; }
	uint32_t getTotalCertificateVerifies() const { return total_certificate_verifies_; }
	uint32_t getTotalServerDones() const { return total_server_dones_; }
	uint32_t getTotalServerKeyExchanges() const { return total_server_key_exchanges_; }
	uint32_t getTotalClientKeyExchanges() const { return total_client_key_exchanges_; }
	uint32_t getTotalHandshakeFinishes() const { return total_handshake_finishes_; }
	uint32_t getTotalNewSessionTickets() const { return total_new_session_tickets_; }
	uint32_t getTotalRecords() const { return total_records_; }
#endif

	// Stores the flow manager handle used by this protocol.
	void setFlowManager(FlowManagerPtrWeak flow_mng) { flow_mng_ = flow_mng; }

	uint64_t getCurrentUseMemory() const override;
	uint64_t getAllocatedMemory() const override;
	uint64_t getTotalAllocatedMemory() const override;

	void setDynamicAllocatedMemory(bool value) override;
	bool isDynamicAllocatedMemory() const override;

	CounterMap getCounters() const override;
	void resetCounters() override;

	void releaseFlowInfo(Flow *flow) override;

	// Returns the value of current_flow_ (nullptr until it is assigned).
	Flow *getCurrentFlow() const { return current_flow_; }

private:
	// View over the header of the packet last given to setHeader (not owned).
	const dtls_header *header_ {nullptr};

	// Statistics per record type
	uint32_t total_handshakes_ {0};
	uint32_t total_encrypted_handshakes_ {0};
	uint32_t total_alerts_ {0};
	uint32_t total_change_cipher_specs_ {0};
	uint32_t total_data_ {0};

	// Statistics per handshake message type
	uint32_t total_client_hellos_ {0};
	uint32_t total_hello_verifies_requests_ {0};
	uint32_t total_server_hellos_ {0};
	uint32_t total_certificates_ {0};
	uint32_t total_certificate_requests_ {0};
	uint32_t total_certificate_verifies_ {0};
	uint32_t total_server_dones_ {0};
	uint32_t total_server_key_exchanges_ {0};
	uint32_t total_client_key_exchanges_ {0};
	uint32_t total_handshake_finishes_ {0};
	uint32_t total_new_session_tickets_ {0};
	uint32_t total_records_ {0};

	// Cache of DTLSInfo objects, created together with the protocol.
	Cache<DTLSInfo>::CachePtr info_cache_ = Cache<DTLSInfo>::CachePtr(new Cache<DTLSInfo>("DTLS Info cache"));

	FlowManagerPtrWeak flow_mng_ = FlowManagerPtrWeak();

	Flow *current_flow_ {nullptr};
};

// Shared-ownership handle to a DTLSProtocol instance.
using DTLSProtocolPtr = std::shared_ptr<DTLSProtocol>;

} // namespace aiengine

#endif  // SRC_PROTOCOLS_DTLS_DTLSPROTOCOL_H_
